{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "76fba87d",
   "metadata": {},
   "source": [
    "| [07_information_extraction/02_Bert实体识别.ipynb](https://github.com/shibing624/nlp-tutorial/blob/main/07_information_extraction/02_Bert实体识别.ipynb)  | Bert实体识别  |[Open In Colab](https://colab.research.google.com/github/shibing624/nlp-tutorial/blob/main/07_information_extraction/02_Bert实体识别.ipynb) |\n",
    "\n",
    "# Bert实体识别\n",
    "适用于品牌、人名、地址名称识别，序列标注任务解决方法。\n",
    "\n",
    "基于transformers的预训练模型，识别人名实体，模型地址：https://huggingface.co/models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7da5eba5",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Use the %pip magic (not !pip) so the package is installed into the environment of the running kernel.\n",
    "%pip install transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6c650d4d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at hfl/chinese-macbert-base were not used when initializing BertForTokenClassification: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForTokenClassification were not initialized from the model checkpoint at hfl/chinese-macbert-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'entity': 'LABEL_1', 'score': 0.60737634, 'index': 1, 'word': '王', 'start': 0, 'end': 1}, {'entity': 'LABEL_1', 'score': 0.6435883, 'index': 2, 'word': '宏', 'start': 1, 'end': 2}, {'entity': 'LABEL_1', 'score': 0.62788045, 'index': 3, 'word': '伟', 'start': 2, 'end': 3}, {'entity': 'LABEL_1', 'score': 0.58248323, 'index': 4, 'word': '来', 'start': 3, 'end': 4}, {'entity': 'LABEL_1', 'score': 0.5865929, 'index': 5, 'word': '自', 'start': 4, 'end': 5}, {'entity': 'LABEL_1', 'score': 0.70053256, 'index': 6, 'word': '北', 'start': 5, 'end': 6}, {'entity': 'LABEL_1', 'score': 0.7213193, 'index': 7, 'word': '京', 'start': 6, 'end': 7}, {'entity': 'LABEL_1', 'score': 0.7275812, 'index': 8, 'word': '，', 'start': 7, 'end': 8}, {'entity': 'LABEL_1', 'score': 0.65218896, 'index': 9, 'word': '是', 'start': 8, 'end': 9}, {'entity': 'LABEL_1', 'score': 0.5602354, 'index': 10, 'word': '个', 'start': 9, 'end': 10}, {'entity': 'LABEL_1', 'score': 0.5996493, 'index': 11, 'word': '警', 'start': 10, 'end': 11}, {'entity': 'LABEL_1', 'score': 0.6554733, 'index': 12, 'word': '察', 'start': 11, 'end': 12}, {'entity': 'LABEL_1', 'score': 0.6601639, 'index': 13, 'word': '，', 'start': 12, 'end': 13}, {'entity': 'LABEL_1', 'score': 0.5955751, 'index': 14, 'word': '喜', 'start': 13, 'end': 14}, {'entity': 'LABEL_1', 'score': 0.66636354, 'index': 15, 'word': '欢', 'start': 14, 'end': 15}, {'entity': 'LABEL_1', 'score': 0.5169803, 'index': 16, 'word': '去', 'start': 15, 'end': 16}, {'entity': 'LABEL_0', 'score': 0.5392316, 'index': 17, 'word': '王', 'start': 16, 'end': 17}, {'entity': 'LABEL_1', 'score': 0.524831, 'index': 18, 'word': '府', 'start': 17, 'end': 18}, {'entity': 'LABEL_1', 'score': 0.69037735, 'index': 19, 'word': '井', 'start': 18, 'end': 19}, {'entity': 'LABEL_1', 'score': 0.63688314, 'index': 20, 'word': '游', 'start': 19, 'end': 20}, {'entity': 'LABEL_1', 'score': 0.501772, 'index': 21, 'word': '玩', 'start': 20, 'end': 21}, {'entity': 'LABEL_0', 'score': 0.5587811, 'index': 22, 
'word': '儿', 'start': 21, 'end': 22}, {'entity': 'LABEL_1', 'score': 0.5066337, 'index': 23, 'word': '。', 'start': 22, 'end': 23}]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from transformers import AutoModelForTokenClassification, AutoTokenizer\n",
    "from transformers import pipeline\n",
    "\n",
    "# NOTE(review): presumably set to tolerate a duplicated OpenMP runtime when torch is loaded - confirm it is still needed.\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
    "\n",
    "# Plain pretrained checkpoint: the warning printed by this cell shows that its token-classification\n",
    "# head is newly initialised, so the predicted labels are only the generic LABEL_0 / LABEL_1.\n",
    "model_name = \"hfl/chinese-macbert-base\"\n",
    "sequence = \"王宏伟来自北京，是个警察，喜欢去王府井游玩儿。\"\n",
    "\n",
    "# Build the ready-made NER pipeline, run it on the sentence and show the per-token results.\n",
    "ner_pipeline = pipeline(task=\"ner\", model=model_name, tokenizer=model_name)\n",
    "ner_results = ner_pipeline(sequence)\n",
    "print(ner_results)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04fff3bd",
   "metadata": {},
   "source": [
    "不使用pipeline，自己编写NER任务预测代码："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fffd3640",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at hfl/chinese-macbert-base were not used when initializing BertForTokenClassification: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForTokenClassification were not initialized from the model checkpoint at hfl/chinese-macbert-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('[CLS]', 'O'), ('王', 'O'), ('宏', 'O'), ('伟', 'O'), ('来', 'B-PER'), ('自', 'B-PER'), ('北', 'B-PER'), ('京', 'B-PER'), ('，', 'B-PER'), ('是', 'O'), ('个', 'B-PER'), ('警', 'B-PER'), ('察', 'B-PER'), ('，', 'B-PER'), ('喜', 'B-PER'), ('欢', 'B-PER'), ('去', 'B-PER'), ('王', 'B-PER'), ('府', 'B-PER'), ('井', 'O'), ('游', 'B-PER'), ('玩', 'B-PER'), ('儿', 'B-PER'), ('。', 'B-PER'), ('[SEP]', 'O')]\n"
     ]
    }
   ],
   "source": [
    "# Manual NER prediction without the pipeline helper: encode -> forward pass -> argmax -> map ids to labels.\n",
    "#\n",
    "# Alternative English checkpoint (fine-tuned on CoNLL-03, judging by its name):\n",
    "#   model:     dbmdz/bert-large-cased-finetuned-conll03-english\n",
    "#   tokenizer: bert-base-cased\n",
    "label_list = [\n",
    "    \"O\",  # Outside of a named entity\n",
    "    \"B-PER\",  # Beginning of a person's name right after another person's name\n",
    "    \"I-PER\",  # Person's name\n",
    "    \"B-ORG\",  # Beginning of an organisation right after another organisation\n",
    "    \"I-ORG\",  # Organisation\n",
    "    \"B-LOC\",  # Beginning of a location right after another location\n",
    "    \"I-LOC\"  # Location\n",
    "]\n",
    "\n",
    "# Size the classification head to the label set; with the default head (2 outputs) only the first\n",
    "# two entries of label_list could ever be predicted.\n",
    "# NOTE: the head of this checkpoint is newly initialised (see the warning printed on load), so the\n",
    "# predicted labels are not meaningful until the model is fine-tuned on an NER dataset.\n",
    "model = AutoModelForTokenClassification.from_pretrained(model_name, num_labels=len(label_list))\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "sequence = \"王宏伟来自北京，是个警察，喜欢去王府井游玩儿。\"\n",
    "inputs = tokenizer.encode(sequence, return_tensors=\"pt\")\n",
    "\n",
    "# Read the tokens (including [CLS]/[SEP]) straight from the ids fed to the model, so tokens and\n",
    "# predictions are guaranteed to line up; a decode -> re-tokenize round trip is not guaranteed to.\n",
    "tokens = tokenizer.convert_ids_to_tokens(inputs[0].tolist())\n",
    "\n",
    "# Inference only: no gradients are needed.\n",
    "with torch.no_grad():\n",
    "    outputs = model(inputs).logits\n",
    "predictions = torch.argmax(outputs, dim=2)\n",
    "\n",
    "print([(token, label_list[prediction]) for token, prediction in zip(tokens, predictions[0].tolist())])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b74bc373",
   "metadata": {},
   "source": [
    "本节完。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "700b6906",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}